import os
import math
import argparse
import wandb
import time

import torch
import torch.optim as optim
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

import torchvision.models.resnet as resnet
#from model import resnet34, resnet101     # 自己写model.py
from multi_train_utils.train_eval_utils import train_one_epoch, evaluate

import sys
sys.path.append("..")
from custom_dataset.my_dataset import MyDataSet
from custom_dataset.utils import read_split_data

'''
单GPU训练(单机单卡)
'''
def _build_data_loaders(train_info, val_info, batch_size):
    """Build the training and validation ``DataLoader`` pair.

    Args:
        train_info: ``(image_paths, labels)`` for the training split, as
            returned by ``read_split_data``.
        val_info: ``(image_paths, labels)`` for the validation split.
        batch_size: samples per batch for both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_images_path, train_images_label = train_info
    val_images_path, val_images_label = val_info

    # ImageNet mean/std, matching the statistics the torchvision weights were trained with.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])
    val_data_set = MyDataSet(images_path=val_images_path,
                             images_class=val_images_label,
                             transform=data_transform["val"])

    # os.cpu_count() may return None when the count is undeterminable; treat that as 1.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_data_set.collate_fn)
    return train_loader, val_loader


def _load_pretrained_weights(model, weights_path, device):
    """Load compatible tensors from ``weights_path`` into ``model``.

    Only checkpoint entries whose key exists in the model and whose tensor
    shape matches are loaded. For example, a final ``fc`` layer saved for a
    different number of classes is skipped and keeps its fresh
    initialisation. The result of ``load_state_dict`` (missing/unexpected
    keys) is printed.

    Raises:
        FileNotFoundError: if ``weights_path`` does not exist.
    """
    if not os.path.exists(weights_path):
        raise FileNotFoundError("not found weights file: {}".format(weights_path))
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    weights_dict = torch.load(weights_path, map_location=device)
    # Snapshot the model's state dict once rather than rebuilding it per key.
    model_state = model.state_dict()
    load_weights_dict = {k: v for k, v in weights_dict.items()
                         if k in model_state and model_state[k].shape == v.shape}
    print(model.load_state_dict(load_weights_dict, strict=False))


def main(args):
    """Train a ResNet-34 classifier on a single device, logging to wandb.

    Args:
        args: parsed command-line namespace. Reads ``device``, ``data_path``,
            ``num_classes``, ``batch_size``, ``weights``, ``freeze_layers``,
            ``lr``, ``lrf`` and ``epochs``.

    Side effects: starts a wandb run, creates ``./weights`` and saves
    ``./weights/model-{epoch}.pth`` after every epoch.
    """
    # Fall back to CPU when CUDA is not available.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)

    # Start a wandb run named by the current timestamp.
    print('Start wandb, view at https://wandb.ai/')
    wandb.init(project='train_single_gpu', name=time.strftime('%m%d%H%M%S'))
    train_log = {}

    # Checkpoint directory (one file is written per epoch).
    os.makedirs("./weights", exist_ok=True)

    train_info, val_info, num_classes = read_split_data(args.data_path)

    # The class count discovered in the dataset must match the CLI value.
    assert args.num_classes == num_classes, "dataset num_classes: {}, input {}".format(num_classes,
                                                                                       args.num_classes)

    train_loader, val_loader = _build_data_loaders(train_info, val_info, args.batch_size)
    num_val = len(val_loader.dataset)

    model = resnet.resnet34(num_classes=args.num_classes).to(device)
    if args.weights != "":
        _load_pretrained_weights(model, args.weights, device)

    if args.freeze_layers:
        # Freeze every parameter except those of the final fully-connected layer.
        for name, para in model.named_parameters():
            if "fc" not in name:
                para.requires_grad_(False)

    # Only parameters that still require gradients are optimised.
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)

    def lf(x):
        # Cosine decay of the LR factor from 1 to args.lrf over args.epochs
        # (https://arxiv.org/pdf/1812.01187.pdf).
        return ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # Capture the LR actually used this epoch, before the scheduler advances it.
        epoch_lr = optimizer.param_groups[0]["lr"]

        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # evaluate returns the count of correct predictions on the validation set.
        sum_num = evaluate(model=model,
                           data_loader=val_loader,
                           device=device)
        acc = sum_num / num_val
        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))

        # Log once per epoch to wandb.
        train_log['mean_loss'] = mean_loss
        train_log['acc'] = acc
        train_log['lr'] = epoch_lr
        wandb.log(train_log)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True. This converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # An empty string stays False, as it was with the previous ``type=bool``.
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Number of classification categories.
    parser.add_argument('--num_classes', type=int, default=5)
    # Number of training epochs.
    parser.add_argument('--epochs', type=int, default=30)
    # Batch size.
    parser.add_argument('--batch-size', type=int, default=16)
    # Initial learning rate.
    parser.add_argument('--lr', type=float, default=0.001)
    # Final LR factor: the learning rate decays towards lr * lrf.
    parser.add_argument('--lrf', type=float, default=0.1)

    # Dataset root directory.
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="/home/lighthouse/gitee/image-processing/data_set/flower_data/flower_photos")

    # Official resnet34 weights:
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    # (download and place next to this script, or pass "" to train from scratch)
    parser.add_argument('--weights', type=str, default='resNet34.pth',
                        help='initial weights path')
    # Freeze all layers except the final fully-connected layer (default: no).
    # Accepts "--freeze-layers", "--freeze-layers True" or "--freeze-layers False".
    parser.add_argument('--freeze-layers', type=str2bool, nargs='?', const=True, default=False)
    # Device to train on (default: cuda; falls back to cpu if CUDA is unavailable).
    parser.add_argument('--device', default='cuda', help='device (e.g. cuda, cuda:0 or cpu)')

    opt = parser.parse_args()

    main(opt)
